import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, reduction='mean'):
        """
        Focal Loss, ref paper:
            http://openaccess.thecvf.com/content_ICCV_2017/papers/Lin_Focal_Loss_for_ICCV_2017_paper.pdf

        :param alpha: float, weight for the positive class; negatives get (1 - alpha).
                      Detail ref paper above.
        :param gamma: float, focusing exponent on (1 - p_t); detail ref paper above
        :param reduction: str, 'mean'|'sum'|'none', reduction type
        """
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, output, target):
        """
        :param output: Tensor, raw logits, shape (batch_size, class_nums)
        :param target: Tensor, either class indices of shape (batch_size,)
                       or a one-hot tensor of shape (batch_size, class_nums)
        :return: Tensor, scalar for 'mean'/'sum' reduction, per-element for 'none'
        """
        # BUG FIX: the original compared `output.dim != target.dim` (bound method
        # objects, always unequal for distinct tensors), so an already one-hot
        # target was re-scattered with a mismatched index shape and crashed.
        # Compare the actual ranks by calling dim().
        if output.dim() != target.dim():
            # convert class-index target to one-hot
            target = torch.zeros_like(output).scatter_(1, target.unsqueeze(1), 1)
        # binary_cross_entropy requires a floating-point target; a user-supplied
        # integer one-hot target would otherwise raise.
        target = target.to(output.dtype)
        # convert logits to pseudo probability
        p = output.sigmoid()
        # alpha_t: alpha on positives, (1 - alpha) on negatives
        a = target * self.alpha + (1 - target) * (1 - self.alpha)
        # pt here is (1 - p_t): small when the prediction is confidently correct
        pt = 1 - (target * p + (1 - target) * (1 - p))
        # detach: the modulating factor acts as a constant per-element weight,
        # so no gradient flows through it (common focal-loss implementation choice)
        focal_weight = (a * pt.pow(self.gamma)).detach()

        focal_loss = F.binary_cross_entropy(p, target, weight=focal_weight, reduction=self.reduction)
        return focal_loss
